function net = TRLearnPluVel(netname,TrainDS,ValDS,TestDS)
%TRLearnPluVel - CNN transfer learning / training resume
%
%   net = TRLearnPluVel
%   net = TRLearnPluVel(ParamWN)
%   net = TRLearnPluVel(modelname)
%   net = TRLearnPluVel(net)
%   net = TRLearnPluVel({net,ParamWN})
%   net = TRLearnPluVel([],TrainDS,ValDS,TestDS)
%   net = TRLearnPluVel(ParamWN,TrainDS,ValDS,TestDS)
%   net = TRLearnPluVel(modelname,TrainDS,ValDS,TestDS)
%   net = TRLearnPluVel(net,TrainDS,ValDS,TestDS)
%   net = TRLearnPluVel({net,ParamWN},TrainDS,ValDS,TestDS)
%
% This function carries out either the transfer learning of a pretrained
% CNN model or the training resume of an already trained CNN. The trained
% CNN is returned in the output variable net.
%
% If the first argument is undefined or empty, i.e. in the cases
%   net = TRLearnPluVel
%   net = TRLearnPluVel([],TrainDS,ValDS,TestDS),
% a menu box allows the choice among these options:
%   TRANSFER LEARNING FROM A MODEL - DEFAULT PARAMETERS
%   TRANSFER LEARNING FROM A MODEL - MANAGE PARAMETERS FILE
%   RESUME CNN TRAINING - DEFAULT PARAMETERS
%   RESUME CNN TRAINING - MANAGE PARAMETERS FILE
% With TRANSFER LEARNING FROM A MODEL - DEFAULT PARAMETERS, the model and
% the other parameters carried by the ParamWN struct returned by DefParam
% are used (a successful call of a model requires the corresponding
% support package).
% With TRANSFER LEARNING FROM A MODEL - MANAGE PARAMETERS FILE, the model
% and the other parameters are taken from the ParamWN variable of an
% interactively chosen file (the file must have a field named 'ParamWN';
% the support-package requirement above still applies).
% With RESUME CNN TRAINING - DEFAULT PARAMETERS, a dialog box allows the
% choice of the file containing the CNN whose training is resumed (the
% file must have a field named 'net'); the parameters returned by DefParam
% are used.
% With RESUME CNN TRAINING - MANAGE PARAMETERS FILE, two dialog boxes allow
% the choice of the file containing the CNN whose training is resumed
% (compulsory field: 'net') and of the parameters file (compulsory field:
% 'ParamWN').
% Closing the menu box or cancelling a file dialog raises an error.
%
% If the first argument is the struct ParamWN, a transfer learning is
% carried out and the model is taken from ParamWN. NOTE(review): the field
% of ParamWN holding the model name is defined in DefParam, which is not
% visible here; the fields 'netname', 'NetName', 'modelname', 'ModelName',
% 'model' and 'Model' are tried in this order and, if none of them holds a
% name, a menu box asks for the model.
%
% If the first argument is a character vector or string modelname (e.g.
% 'alexnet', 'vgg16', 'vgg19', 'googlenet', 'resnet18'), a transfer
% learning of that model is carried out with the parameters returned by
% DefParam.
%
% If the first argument is a trained CNN (SeriesNetwork or DAGNetwork), its
% training is resumed with the parameters returned by DefParam.
%
% If the first argument is a two-element cell {net,ParamWN} (in either
% order), the CNN element is the network whose training is resumed and the
% other element is either a ParamWN struct or the name of a file having a
% field named 'ParamWN'.
%
% The 2nd, 3rd and 4th input arguments are the training, validation and
% test image datastores respectively. If one or more of them is undefined
% or empty, a dialog box allows the choice of a file that must carry at
% least the three fields TrainDS, ValDS and TestDS (case sensitive); all
% three datastores are then read from that file.
%
% In the case of transfer learning, the last fully connected layer and the
% classification layer of the model are replaced so that the number of
% classes matches the number of labels of TrainDS.
%
% The image size (ParamWN.SizeIm(1) for transfer learning, the input-layer
% size of net for training resume) is stored in the global variable x.
%
% See also PluVelInspect, PluVelScalogram, trendClass, dataHomAug, dataTVT.

% G. Teza, 2020

%% Input data management

% opTrain = 1 -> transfer learning (TL), opTrain = 2 -> training resume (TR)
modelName = '';     % only used for TL; filled below when still empty

if nargin < 1 || isempty(netname)

    menopTrain = menu('OPTIONS ON TRAINING:',...
        'TRANSFER LEARNING FROM A MODEL - DEFAULT PARAMETERS',...
        'TRANSFER LEARNING FROM A MODEL - MANAGE PARAMETERS FILE',...
        'RESUME CNN TRAINING - DEFAULT PARAMETERS',...
        'RESUME CNN TRAINING - MANAGE PARAMETERS FILE');
    switch menopTrain
        case 1
            opTrain = 1;
            ParamWN = DefParam;
        case 2
            opTrain = 1;
            ParamWN = loadMatField('ParamWN','PARAMETER FILE','ParamWN.mat');
        case 3
            opTrain = 2;
            net = loadMatField('net','CNN file','net.mat');
            ParamWN = DefParam;
        case 4
            opTrain = 2;
            net = loadMatField('net','CNN file','net.mat');
            ParamWN = loadMatField('ParamWN','PARAMETER FILE','ParamWN.mat');
        otherwise   % menu returns 0 when the box is closed without a choice
            error('TRLearnPluVel:noOption','No training option selected.');
    end

elseif isstruct(netname)    % transfer learning, model taken from ParamWN

    opTrain = 1;
    ParamWN = netname;

elseif ischar(netname) || isstring(netname)    % transfer learning from a named model

    opTrain = 1;
    modelName = char(netname);
    ParamWN = DefParam;

elseif isTrainedNet(netname)    % resume training - option 1

    opTrain = 2;
    net = netname;
    ParamWN = DefParam;

elseif iscell(netname) && numel(netname) == 2   % resume training - option 2

    opTrain = 2;
    % Brace indexing extracts the cell contents (paren indexing would
    % return 1x1 cells, which are never networks).
    if isTrainedNet(netname{1})
        net = netname{1};
        pof = netname{2};
    else
        net = netname{2};
        pof = netname{1};
    end
    if ~isTrainedNet(net)
        error('TRLearnPluVel:noNet',...
            'The cell input must contain a SeriesNetwork or DAGNetwork.');
    end
    if isstruct(pof)
        ParamWN = pof;
    else
        fop = load(pof);
        ParamWN = fop.ParamWN;
    end

else

    error('Invalid input');

end

if opTrain == 1
    % Transfer learning: instantiate the pretrained model (requires the
    % corresponding support package); image size comes from ParamWN.
    if isempty(modelName)
        modelName = modelNameFromParam(ParamWN);
    end
    net = feval(modelName);
    nrc = ParamWN.SizeIm(1);
else
    % Training resume: image size comes from the image input layer.
    nl = net.Layers;
    nrc = nl(1).InputSize(1);
end

setGlobalx(nrc);

if nargin < 4 || isempty(TrainDS) || isempty(ValDS) || isempty(TestDS)
    IMDS = load(pickMatFile('Image Datastore file'));
    TrainDS = IMDS.TrainDS;
    ValDS = IMDS.ValDS;
    TestDS = IMDS.TestDS;
end

tTr = countEachLabel(TrainDS);
tVa = countEachLabel(ValDS);
tTe = countEachLabel(TestDS);
NL = height(tTr);          % number of labels

fprintf('\nTraining dataset label count:\n\n');
disp(tTr);
fprintf('\nValidation dataset label count:\n\n');
disp(tVa);
fprintf('\nTest dataset label count:\n\n');
disp(tTe);

%% Visualize random images from the set

RIMDS = splitEachLabel(TrainDS,1,'randomize'); % one random image per label

% Grid large enough for NL images (2x2 and 2x4 layouts kept for small NL).
if NL <= 4
    nRows = 2;
    nCols = 2;
elseif NL <= 8
    nRows = 2;
    nCols = 4;
else
    nCols = ceil(sqrt(NL));
    nRows = ceil(NL/nCols);
end

figure;
for k = 1:NL
    subplot(nRows,nCols,k);
    imshow(readimage(RIMDS,k));
    title(char(RIMDS.Labels(k)));
end

%% Set the network fine tuning

MiniBatchSize = ParamWN.MiniBatchSize; % number of images processed at
    % once (if the GPU runs out of memory, MiniBatchSize should be lowered)
MaxEpochs = ParamWN.MaxEpochs;  % one epoch is one complete pass through the training data

options = trainingOptions('sgdm', ...
    'MiniBatchSize',MiniBatchSize, ...
    'MaxEpochs',MaxEpochs, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',ValDS, ...
    'ValidationFrequency',30, ...
    'ValidationPatience',10, ...
    'Verbose',true, ...
    'Plots','training-progress');

% Alternative with a piecewise learning-rate schedule:
% options = trainingOptions('sgdm', ...
%     'MiniBatchSize',MiniBatchSize, ...
%     'MaxEpochs',MaxEpochs, ...
%     'InitialLearnRate',1e-4, ...
%     'LearnRateSchedule','piecewise',...
%     'LearnRateDropFactor',0.2,...
%     'LearnRateDropPeriod',10,...
%     'Shuffle','every-epoch', ...
%     'ValidationData',ValDS, ...
%     'ValidationFrequency',30, ...
%     'ValidationPatience',10, ...
%     'Verbose',true, ...
%     'Plots','training-progress');

%% Computations

disp("Initialization may take up to a minute before training begins")

if opTrain == 1         % transfer learning

    fprintf('\nCNN layers before transfer learning: \n');
    disp(net.Layers);

    layers = adaptHead(net,NL);

    tic
    net = trainNetwork(TrainDS,layers,options);
    toc

    fprintf('\nCNN layers after transfer learning: \n');
    disp(net.Layers);

else                    % resume training

    % Prefer the layer graph (needed for DAG networks); fall back to the
    % plain layer array if layerGraph does not accept this network.
    try
        layersResume = layerGraph(net);
    catch
        layersResume = net.Layers;
    end

    tic
    net = trainNetwork(TrainDS,layersResume,options);
    toc

end

%% Test new classifier on test set

fprintf('\nCNN Test in progress... \n');
tic
labels = classify(net,TestDS,'MiniBatchSize',MiniBatchSize);
toc

confMat = confusionmat(TestDS.Labels,labels);
confMat = confMat./sum(confMat,2);   % row-normalize: per-class recall

% Classes absent from the test set give NaN rows; ignore them in the mean.
ma = mean(diag(confMat),'omitnan');
fprintf('\nMean accuracy from confusion matrix: %3.2f\n',ma);

% Confusion matrix: heat map
tt = table(TestDS.Labels,labels,'VariableNames',{'Actual','Predicted'});
figure; heatmap(tt,'Predicted','Actual');


%% Ancillary functions

function tf = isTrainedNet(obj)
% True for the trained-network classes accepted for training resume.
tf = isa(obj,'SeriesNetwork') || isa(obj,'DAGNetwork');

function fullName = pickMatFile(dlgTitle,defName)
% Ask the user for a .mat file; error out if the dialog is cancelled.
if nargin < 2
    [fileName,pathName] = uigetfile('*.mat',dlgTitle);
else
    [fileName,pathName] = uigetfile('*.mat',dlgTitle,defName);
end
if isequal(fileName,0)      % uigetfile returns 0 on cancel
    error('TRLearnPluVel:noFile','No file selected (%s).',dlgTitle);
end
fullName = fullfile(pathName,fileName);

function val = loadMatField(fieldName,dlgTitle,defName)
% Load the variable fieldName from an interactively chosen .mat file.
S = load(pickMatFile(dlgTitle,defName));
if ~isfield(S,fieldName)
    error('TRLearnPluVel:missingField',...
        'The selected file has no field ''%s''.',fieldName);
end
val = S.(fieldName);

function name = modelNameFromParam(ParamWN)
% Model name for transfer learning, read from ParamWN when possible.
% NOTE(review): the ParamWN field holding the model is defined in DefParam,
% which is not visible here; the candidate names below are a guess - confirm.
candidates = {'netname','NetName','modelname','ModelName','model','Model'};
for k = 1:numel(candidates)
    if isfield(ParamWN,candidates{k})
        val = ParamWN.(candidates{k});
        if ischar(val) || isstring(val)
            name = char(val);
            return
        end
    end
end
name = netNameAcquire;      % fall back to an interactive choice

function netname = netNameAcquire
% Interactive choice among the supported pretrained models.
models = {'alexnet','vgg16','vgg19','googlenet','resnet18'};
mennet = menu('CHOICE OF CNN MODEL:',models{:});
if mennet == 0              % menu box closed without a choice
    error('TRLearnPluVel:noModel','No CNN model selected.');
end
netname = models{mennet};

function layers = adaptHead(net,NL)
% Replace the classifier head of net so that it outputs NL classes; the
% result (Layer array or layerGraph) can be passed to trainNetwork.
layers = net.Layers;
isFC = arrayfun(@(L) isa(L,'nnet.cnn.layer.FullyConnectedLayer'),layers);
isCL = arrayfun(@(L) isa(L,'nnet.cnn.layer.ClassificationOutputLayer'),layers);
iFC = find(isFC,1,'last');
iCL = find(isCL,1,'last');
if isempty(iFC) || isempty(iCL)
    error('TRLearnPluVel:noHead',...
        'The model has no fully connected / classification layer to replace.');
end
% The new layers keep the old names so that the graph connections still
% hold; higher learn rate factors speed up learning of the new classifier.
newFC = fullyConnectedLayer(NL,'Name',layers(iFC).Name,...
    'WeightLearnRateFactor',10,'BiasLearnRateFactor',20);
newCL = classificationLayer('Name',layers(iCL).Name);
if isa(net,'DAGNetwork')
    % DAG models (e.g. googlenet, resnet18) cannot be retrained from their
    % flat Layer array: the layers are replaced inside the layer graph.
    lgraph = layerGraph(net);
    lgraph = replaceLayer(lgraph,newFC.Name,newFC);
    layers = replaceLayer(lgraph,newCL.Name,newCL);
else
    layers(iFC) = newFC;
    layers(iCL) = newCL;
end

function setGlobalx(val)
%setGlobalx - store a value in the global variable x
%
%   setGlobalx(val) overwrites the global variable x with val, so that
%   every function that declares "global x" sees the new value.
global x;
x = val;    % previous content of x is discarded